Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions painter/src/create_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from im_utils import is_image
from palette import PaletteEditWidget
from name_edit_widget import NameEditWidget
from split_images import split_single_image

class CreateProjectWidget(QtWidgets.QWidget):

Expand All @@ -40,6 +41,8 @@ def __init__(self, sync_dir):
self.selected_model = None
self.use_random_weights = True
self.sync_dir = sync_dir
self.guide_image = None
self.patch_image = False
self.initUI()

def initUI(self):
Expand All @@ -50,7 +53,9 @@ def initUI(self):
self.layout.addWidget(self.name_edit_widget)

self.add_im_dir_widget()
self.add_guide_im_dir_widget()
self.add_radio_widget()
self.add_patch_widget()
self.add_model_btn()
self.add_palette_widget()

Expand All @@ -76,6 +81,16 @@ def add_im_dir_widget(self):
specify_image_dir_btn.clicked.connect(self.select_photo_dir)
self.layout.addWidget(specify_image_dir_btn)

def add_guide_im_dir_widget(self):
guide_directory_label = QtWidgets.QLabel()
guide_directory_label.setText("Guide image directory: Not yet specified")
self.layout.addWidget(guide_directory_label)
self.guide_directory_label = guide_directory_label

specify_image_dir_btn = QtWidgets.QPushButton('Specify guide image directory (optional)')
specify_image_dir_btn.clicked.connect(self.select_guide_im_dir)
self.layout.addWidget(specify_image_dir_btn)

def add_radio_widget(self):
radio_widget = QtWidgets.QWidget()
radio_layout = QtWidgets.QHBoxLayout()
Expand All @@ -93,6 +108,20 @@ def add_radio_widget(self):
radio.name = "specify"
radio.toggled.connect(self.on_radio_clicked)
radio_layout.addWidget(radio)

def add_patch_widget(self):
patch_widget = QtWidgets.QWidget()
patch_layout = QtWidgets.QHBoxLayout()
patch_widget.setLayout(patch_layout)
self.layout.addWidget(patch_widget)

patch = QtWidgets.QCheckBox("Turn image into patches")
patch.setChecked(False)
patch.name = "patched"
patch.toggled.connect(self.on_patch_cliked)
patch_layout.addWidget(patch)



def add_model_btn(self):
model_label = QtWidgets.QLabel()
Expand Down Expand Up @@ -129,6 +158,15 @@ def on_radio_clicked(self):
self.specify_model_btn.setVisible(specify)
self.use_random_weights = not specify
self.validate()

def on_patch_cliked(self):
patch = self.sender()
if patch.isChecked():
self.patch_image = True
else:
self.patch_image = False
self.validate()


def validate(self):
self.proj_name = self.name_edit_widget.name
Expand All @@ -155,6 +193,17 @@ def validate(self):
self.info_label.setText(message)
self.create_project_btn.setEnabled(False)
return

if not self.guide_image == None:
cur_guides = os.listdir(self.guide_image)
cur_guides = [f for f in cur_guides if is_image(f)]
if not cur_guides:
message = "Folder of guide images contains no compatible images. Valid formats include NIfTI (.nii.gz) and nrrd. \nYou can continue without guide images"
self.info_label.setText(message)
# self.create_project_btn.setEnabled(False)
self.guide_image = None
return


if len(self.palette_edit_widget.get_brush_data()) < 1:
self.info_label.setText(f"At least one foreground brush must be specified")
Expand All @@ -179,6 +228,17 @@ def output_selected():

self.photo_dialog.fileSelected.connect(output_selected)
self.photo_dialog.open()

def select_guide_im_dir(self):
self.guide_dialog = QtWidgets.QFileDialog(self)
self.guide_dialog.setFileMode(QtWidgets.QFileDialog.Directory)
def output_selected():
self.guide_image = self.guide_dialog.selectedFiles()[0]
self.guide_directory_label.setText('Guide Image directory: ' + self.guide_image)
self.validate()

self.guide_dialog.fileSelected.connect(output_selected)
self.guide_dialog.open()

def select_model(self):
options = QtWidgets.QFileDialog.Options()
Expand Down Expand Up @@ -208,6 +268,17 @@ def create_project(self):
f"{datasets_dir}.")
QtWidgets.QMessageBox.about(self, 'Project Creation Error', message)
return

if self.guide_image:
guides_path = os.path.abspath(self.guide_image)

if not guides_path.startswith(datasets_dir):
message = ("When creating a project the selected dataset guides must "
"be in the datasets folder. The selected dataset is "
f"{guides_path} and the datasets folder is "
f"{datasets_dir}.")
QtWidgets.QMessageBox.about(self, 'Project Creation Error', message)
return

os.makedirs(self.sync_dir / project_location)
proj_file_path = (self.sync_dir / project_location /
Expand Down Expand Up @@ -267,6 +338,50 @@ def create_project(self):
'location': str(PurePosixPath(project_location))
}

if self.guide_image:
project_info['guide_image_dir'] = self.guide_image

if self.patch_image:
if len(all_fnames) == 1:
"ENTER SPLITTING"
dataset = os.path.join(dataset_path, all_fnames[0])
split_single_image(dataset, dataset.replace(".nii.gz", ''))
print(f"{dataset=}")
dataset = dataset.replace(".nii.gz", '')
all_fnames = os.listdir(dataset)
all_fnames.sort()
project_info['dataset'] = dataset

if self.guide_image:
print(f"{self.guide_image=}")
all_gnames = os.listdir(self.guide_image)
#only images
all_gnames = [a for a in all_gnames if is_image(a)]
# ignore these 'hidden' files.
all_gnames = [a for a in all_gnames if a[0] != '.']
if len(all_gnames) > 1:
message = ("Image folder contains one image, but the"
f"image folder contains {len(all_gnames)} images."
"Please only include the correct image.")
QtWidgets.QMessageBox.about(self, 'Project Creation Error', message)
return
gname = all_gnames[0]
print(f"{gname=}")

guide_path = os.path.join(self.guide_image, gname).replace(".nii.gz", '')
split_single_image(guide_path+".nii.gz", guide_path)
project_info['guide_image_dir'] = guide_path
print(f"{guide_path=}")
all_guides = os.listdir(guide_path)
all_guides.sort()
if all_fnames != all_guides:
message = ("guide image patching failded because the dimensions of"
"of the guide image did not match the actual image")
QtWidgets.QMessageBox.about(self, 'Project Creation Error', message)
return



# only add classes info if the palette is defined.
# otherwise the server will default to single class (fg/bg)
if hasattr(self, 'palette_edit_widget'):
Expand Down
39 changes: 39 additions & 0 deletions painter/src/split_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# from matplotlib import pyplot as plt
import numpy as np
import nibabel as nib
import sys
import os

def split_single_image(fname, save_folder, prefix=""):
if (not os.path.isdir(save_folder)):
os.mkdir(save_folder)
full_img = nib.load(fname)
full_img = np.array(full_img.dataobj)
ax1, ax2, ax3 = full_img.shape
len1 = ax1//4
len2 = ax2//4
len3 = ax3//2
print(f"Creating image patches from {fname}. Creating the following files:")
for i in range(ax1//len1):
for j in range(ax2//len2):
for k in range(ax3//len3):
ax1min = i*len1
ax1max = (i+1)*len1
ax2min = j*len2
ax2max = (j+1)*len2
ax3min = k*len3
ax3max = (k+1)*len3
filename = f"{prefix}ax1({ax1min}-{ax1max})ax2({ax2min}-{ax2max})ax3({ax3min}-{ax3max})"
img = full_img[ax1min:ax1max, ax2min:ax2max, ax3min:ax3max]
img = nib.Nifti1Image(img, np.eye(4))
out_path = f"{save_folder}/{filename}.nii.gz"
print(out_path)
nib.save(img, out_path)

def split_images(image_folder, save_folder):
if (not os.path.isdir(save_folder)):
os.mkdir(save_folder)
fnames = os.listdir(image_folder)
for f in fnames:
split_single_image(f"{image_folder}/{f}", f"{save_folder}/{f[:-7]}")