From 60d928447873d40593848aafcbc0b629070890b1 Mon Sep 17 00:00:00 2001 From: SofieAastrup Date: Mon, 13 May 2024 00:41:13 +0200 Subject: [PATCH] adds image splitting and guide image loading --- painter/src/create_project.py | 115 ++++++++++++++++++++++++++++++++++ painter/src/split_images.py | 39 ++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 painter/src/split_images.py diff --git a/painter/src/create_project.py b/painter/src/create_project.py index e39e4dc..f2784ad 100644 --- a/painter/src/create_project.py +++ b/painter/src/create_project.py @@ -28,6 +28,7 @@ from im_utils import is_image from palette import PaletteEditWidget from name_edit_widget import NameEditWidget +from split_images import split_single_image class CreateProjectWidget(QtWidgets.QWidget): @@ -40,6 +41,8 @@ def __init__(self, sync_dir): self.selected_model = None self.use_random_weights = True self.sync_dir = sync_dir + self.guide_image = None + self.patch_image = False self.initUI() def initUI(self): @@ -50,7 +53,9 @@ def initUI(self): self.layout.addWidget(self.name_edit_widget) self.add_im_dir_widget() + self.add_guide_im_dir_widget() self.add_radio_widget() + self.add_patch_widget() self.add_model_btn() self.add_palette_widget() @@ -76,6 +81,16 @@ def add_im_dir_widget(self): specify_image_dir_btn.clicked.connect(self.select_photo_dir) self.layout.addWidget(specify_image_dir_btn) + def add_guide_im_dir_widget(self): + guide_directory_label = QtWidgets.QLabel() + guide_directory_label.setText("Guide image directory: Not yet specified") + self.layout.addWidget(guide_directory_label) + self.guide_directory_label = guide_directory_label + + specify_image_dir_btn = QtWidgets.QPushButton('Specify guide image directory (optional)') + specify_image_dir_btn.clicked.connect(self.select_guide_im_dir) + self.layout.addWidget(specify_image_dir_btn) + def add_radio_widget(self): radio_widget = QtWidgets.QWidget() radio_layout = QtWidgets.QHBoxLayout() @@ -93,6 +108,20 @@ def add_radio_widget(self): radio.name = "specify" radio.toggled.connect(self.on_radio_clicked) radio_layout.addWidget(radio) + + def add_patch_widget(self): + patch_widget = QtWidgets.QWidget() + patch_layout = QtWidgets.QHBoxLayout() + patch_widget.setLayout(patch_layout) + self.layout.addWidget(patch_widget) + + patch = QtWidgets.QCheckBox("Turn image into patches") + patch.setChecked(False) + patch.name = "patched" + patch.toggled.connect(self.on_patch_cliked) + patch_layout.addWidget(patch) + + def add_model_btn(self): model_label = QtWidgets.QLabel() @@ -129,6 +158,15 @@ def on_radio_clicked(self): self.specify_model_btn.setVisible(specify) self.use_random_weights = not specify self.validate() + + def on_patch_cliked(self): + patch = self.sender() + if patch.isChecked(): + self.patch_image = True + else: + self.patch_image = False + self.validate() + def validate(self): self.proj_name = self.name_edit_widget.name @@ -155,6 +193,17 @@ def validate(self): self.info_label.setText(message) self.create_project_btn.setEnabled(False) return + + if not self.guide_image == None: + cur_guides = os.listdir(self.guide_image) + cur_guides = [f for f in cur_guides if is_image(f)] + if not cur_guides: + message = "Folder of guide images contains no compatible images. Valid formats include NIfTI (.nii.gz) and nrrd. \nYou can continue without guide images" + self.info_label.setText(message) + # self.create_project_btn.setEnabled(False) + self.guide_image = None + return + if len(self.palette_edit_widget.get_brush_data()) < 1: self.info_label.setText(f"At least one foreground brush must be specified") @@ -179,6 +228,17 @@ def output_selected(): self.photo_dialog.fileSelected.connect(output_selected) self.photo_dialog.open() + + def select_guide_im_dir(self): + self.guide_dialog = QtWidgets.QFileDialog(self) + self.guide_dialog.setFileMode(QtWidgets.QFileDialog.Directory) + def output_selected(): + self.guide_image = self.guide_dialog.selectedFiles()[0] + self.guide_directory_label.setText('Guide Image directory: ' + self.guide_image) + self.validate() + + self.guide_dialog.fileSelected.connect(output_selected) + self.guide_dialog.open() def select_model(self): options = QtWidgets.QFileDialog.Options() @@ -208,6 +268,17 @@ def create_project(self): f"{datasets_dir}.") QtWidgets.QMessageBox.about(self, 'Project Creation Error', message) return + + if self.guide_image: + guides_path = os.path.abspath(self.guide_image) + + if not guides_path.startswith(datasets_dir): + message = ("When creating a project the selected dataset guides must " + "be in the datasets folder. The selected dataset is " + f"{guides_path} and the datasets folder is " + f"{datasets_dir}.") + QtWidgets.QMessageBox.about(self, 'Project Creation Error', message) + return os.makedirs(self.sync_dir / project_location) proj_file_path = (self.sync_dir / project_location / @@ -267,6 +338,50 @@ def create_project(self): 'location': str(PurePosixPath(project_location)) } + if self.guide_image: + project_info['guide_image_dir'] = self.guide_image + + if self.patch_image: + if len(all_fnames) == 1: + "ENTER SPLITTING" + dataset = os.path.join(dataset_path, all_fnames[0]) + split_single_image(dataset, dataset.replace(".nii.gz", '')) + print(f"{dataset=}") + dataset = dataset.replace(".nii.gz", '') + all_fnames = os.listdir(dataset) + all_fnames.sort() + project_info['dataset'] = dataset + + if self.guide_image: + print(f"{self.guide_image=}") + all_gnames = os.listdir(self.guide_image) + #only images + all_gnames = [a for a in all_gnames if is_image(a)] + # ignore these 'hidden' files. + all_gnames = [a for a in all_gnames if a[0] != '.'] + if len(all_gnames) > 1: + message = ("Image folder contains one image, but the" + f"image folder contains {len(all_gnames)} images." + "Please only include the correct image.") + QtWidgets.QMessageBox.about(self, 'Project Creation Error', message) + return + gname = all_gnames[0] + print(f"{gname=}") + + guide_path = os.path.join(self.guide_image, gname).replace(".nii.gz", '') + split_single_image(guide_path+".nii.gz", guide_path) + project_info['guide_image_dir'] = guide_path + print(f"{guide_path=}") + all_guides = os.listdir(guide_path) + all_guides.sort() + if all_fnames != all_guides: + message = ("guide image patching failded because the dimensions of" + "of the guide image did not match the actual image") + QtWidgets.QMessageBox.about(self, 'Project Creation Error', message) + return + + + # only add classes info if the palette is defined. # otherwise the server will default to single class (fg/bg) if hasattr(self, 'palette_edit_widget'): diff --git a/painter/src/split_images.py b/painter/src/split_images.py new file mode 100644 index 0000000..de55911 --- /dev/null +++ b/painter/src/split_images.py @@ -0,0 +1,39 @@ +# from matplotlib import pyplot as plt +import numpy as np +import nibabel as nib +import sys +import os + +def split_single_image(fname, save_folder, prefix=""): + if (not os.path.isdir(save_folder)): + os.mkdir(save_folder) + full_img = nib.load(fname) + full_img = np.array(full_img.dataobj) + ax1, ax2, ax3 = full_img.shape + len1 = ax1//4 + len2 = ax2//4 + len3 = ax3//2 + print(f"Creating image patches from {fname}. Creating the following files:") + for i in range(ax1//len1): + for j in range(ax2//len2): + for k in range(ax3//len3): + ax1min = i*len1 + ax1max = (i+1)*len1 + ax2min = j*len2 + ax2max = (j+1)*len2 + ax3min = k*len3 + ax3max = (k+1)*len3 + filename = f"{prefix}ax1({ax1min}-{ax1max})ax2({ax2min}-{ax2max})ax3({ax3min}-{ax3max})" + img = full_img[ax1min:ax1max, ax2min:ax2max, ax3min:ax3max] + img = nib.Nifti1Image(img, np.eye(4)) + out_path = f"{save_folder}/{filename}.nii.gz" + print(out_path) + nib.save(img, out_path) + +def split_images(image_folder, save_folder): + if (not os.path.isdir(save_folder)): + os.mkdir(save_folder) + fnames = os.listdir(image_folder) + for f in fnames: + split_single_image(f"{image_folder}/{f}", f"{save_folder}/{f[:-7]}") +