Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 172 additions & 35 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import numpy as np
from PyQt5 import QtWidgets
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPixmap, QIntValidator, QKeySequence
from PyQt5.QtGui import QPixmap, QIntValidator, QKeySequence, QPainter, QPalette
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QCheckBox, QFileDialog, QDesktopWidget, QLineEdit, \
QRadioButton, QShortcut, QScrollArea, QVBoxLayout, QGroupBox, QFormLayout
QRadioButton, QShortcut, QScrollArea, QVBoxLayout, QGraphicsScene, QGraphicsPixmapItem, QGroupBox, QFormLayout
from xlsxwriter.workbook import Workbook


Expand All @@ -23,11 +23,20 @@ def get_img_paths(dir, extensions=('.jpg', '.png', '.jpeg')):

for filename in os.listdir(dir):
if filename.lower().endswith(extensions):
img_paths.append(os.path.join(dir, filename))
img_paths.append(join_path(dir, filename))

return img_paths


def join_path(path, *paths):
"""
Use os.path.abspath to prevent mixed slashes on Windows
"""
path = os.path.join(path, *paths)
path = os.path.abspath(path)
return path


def make_folder(directory):
"""
Make folder if it doesn't already exist
Expand Down Expand Up @@ -214,6 +223,8 @@ def pick_labels_file(self):
# fill the input fileds with loaded labels
for input, label in zip(self.label_inputs, labels):
input.setText(label)
else:
print("Invalid file")

def generate_label_inputs(self):
"""
Expand Down Expand Up @@ -250,6 +261,7 @@ def generate_label_inputs(self):
self.groupBox.setLayout(self.formLayout)
self.scroll.setWidget(self.groupBox)
self.scroll.setWidgetResizable(True)

def centerOnScreen(self):
"""
Centers the window on the screen.
Expand Down Expand Up @@ -296,6 +308,70 @@ def continue_app(self):
self.error_message.setText(message)


class ImageBox(QtWidgets.QGraphicsView):
def __init__(self, parent=None):
super().__init__(parent)
self.setAlignment(Qt.AlignLeft | Qt.AlignTop)
self.setBackgroundRole(QPalette.Background)
#self.setFrameStyle(0)
self.setRenderHints(
QPainter.Antialiasing | QPainter.SmoothPixmapTransform
)

self._pixmap_item = QGraphicsPixmapItem()
self._pixmap_item.setTransformationMode(Qt.SmoothTransformation)

scene = QGraphicsScene()
scene.addItem(self._pixmap_item)
self.setScene(scene)

def load_pixmap(self, pixmap):
self.resetZoom()
self._pixmap_item.setPixmap(pixmap)
return True

def zoomIn(self, viewAnchor=QtWidgets.QGraphicsView.AnchorUnderMouse):
self.zoom(1.1, viewAnchor)

def zoomOut(self, viewAnchor=QtWidgets.QGraphicsView.AnchorUnderMouse):
self.zoom(1/1.1, viewAnchor)

def zoom(self, f, viewAnchor=QtWidgets.QGraphicsView.AnchorUnderMouse):
self.setTransformationAnchor(viewAnchor)
self.scale(f, f)
self.__setDragEnabled(self.__isEnableDrag())
self.setTransformationAnchor(self.AnchorUnderMouse)

def resetZoom(self):
self.resetTransform()
self.__setDragEnabled(False)

def fitToWindow(self):
self.fitInView(self.sceneRect(), Qt.KeepAspectRatio)

def __isEnableDrag(self):
v = self.verticalScrollBar().maximum() > 0
h = self.horizontalScrollBar().maximum() > 0
return v or h

def __setDragEnabled(self, isEnabled):
self.setDragMode(
self.ScrollHandDrag if isEnabled else self.NoDrag)

def wheelEvent(self, event):
mods = event.modifiers()
delta = event.angleDelta()

# If user presses ctrl and middle mouse scroll
if Qt.ControlModifier == int(mods):
if int(delta.y())>0:
self.zoomIn()
else:
self.zoomOut()
else:
super().wheelEvent(event)


class LabelerWindow(QWidget):
def __init__(self, labels, input_folder, mode):
super().__init__()
Expand All @@ -315,7 +391,6 @@ def __init__(self, labels, input_folder, mode):
self.input_folder = input_folder
self.img_paths = get_img_paths(input_folder)
self.labels = labels
self.num_labels = len(self.labels)
self.num_images = len(self.img_paths)
self.assigned_labels = {}
self.mode = mode
Expand All @@ -324,7 +399,7 @@ def __init__(self, labels, input_folder, mode):
self.label_buttons = []

# Initialize Labels
self.image_box = QLabel(self)
self.image_box = ImageBox(self)
self.img_name_label = QLabel(self)
self.progress_bar = QLabel(self)
self.curr_image_headline = QLabel('Current image', self)
Expand Down Expand Up @@ -363,6 +438,10 @@ def init_ui(self):

# image name label
self.img_name_label.setGeometry(20, 40, self.img_panel_width, 20)
self.img_name_label.setCursor(Qt.IBeamCursor)
self.img_name_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
self.img_name_label.setContextMenuPolicy(Qt.CustomContextMenu)
self.img_name_label.customContextMenuRequested.connect(self.show_img_name_label_menu)

# progress bar (how many images have I labeled so far)
self.progress_bar.setGeometry(20, 65, self.img_panel_width, 20)
Expand All @@ -378,6 +457,9 @@ def init_ui(self):
self.set_image(self.img_paths[0])
self.image_box.setGeometry(20, 120, self.img_panel_width, self.img_panel_height)
self.image_box.setAlignment(Qt.AlignTop)
self.image_box.setFixedSize(self.img_panel_width, self.img_panel_height)
self.image_box.setSceneRect(0, 0, self.img_panel_width - 2*self.image_box.frameWidth(),
self.img_panel_height - 2*self.image_box.frameWidth())

# image name
self.img_name_label.setText(self.img_paths[self.counter])
Expand Down Expand Up @@ -423,28 +505,51 @@ def init_buttons(self):
next_im_btn.clicked.connect(lambda state, filename='assigned_classes': self.generate_csv(filename))
next_im_btn.setObjectName("blueButton")

# Add "Add label" button
add_label_btn = QtWidgets.QPushButton("Add label", self)
add_label_btn.move(self.img_panel_width + 20, 560)
add_label_btn.clicked.connect(self.add_label)
add_label_btn.setObjectName("blueButton")

self.newLabelInput = QLineEdit(self)
self.newLabelInput.move(self.img_panel_width + 20, 540)

# Create button for each label
x_shift = 0 # variable that helps to compute x-coordinate of button in UI
for i, label in enumerate(self.labels):
self.label_buttons.append(QtWidgets.QPushButton(label, self))
button = self.label_buttons[i]
button = self.init_label_button(i, label)
self.label_buttons.append(button)

def init_label_button(self, i, label):
button = QtWidgets.QPushButton(label, self)

# create click event (set label)
# https://stackoverflow.com/questions/35819538/using-lambda-expression-to-connect-slots-in-pyqt
button.clicked.connect(lambda state, x=label: self.set_label(x))

# create click event (set label)
# https://stackoverflow.com/questions/35819538/using-lambda-expression-to-connect-slots-in-pyqt
button.clicked.connect(lambda state, x=label: self.set_label(x))
# create keyboard shortcut event (set label)
# shortcuts start getting overwritten when number of labels >9
label_kbs = QShortcut(QKeySequence(f"{i+1 % 10}"), self)
label_kbs.activated.connect(lambda x=label: self.set_label(x))

# create keyboard shortcut event (set label)
# shortcuts start getting overwritten when number of labels >9
label_kbs = QShortcut(QKeySequence(f"{i+1 % 10}"), self)
label_kbs.activated.connect(lambda x=label: self.set_label(x))
# place button in GUI (create multiple columns if there is more than 10 button)
y_shift = (30 + 10) * (i % 10)
x_shift = 120 * (i // 10)

# place button in GUI (create multiple columns if there is more than 10 button)
y_shift = (30 + 10) * (i % 10)
if (i != 0 and i % 10 == 0):
x_shift += 120
y_shift = 0
button.move(self.img_panel_width + 20 + x_shift, y_shift + 120)

button.move(self.img_panel_width + 20 + x_shift, y_shift + 120)
return button

def add_label(self):
new_label = self.newLabelInput.text().strip()
if new_label != '' and new_label not in self.labels:
self.newLabelInput.setText('')
self.labels.append(new_label)
button = self.init_label_button(len(self.labels)-1, new_label)
button.show()
self.label_buttons.append(button)

if self.mode == 'copy' or self.mode == 'move':
self.create_label_folders([new_label], self.input_folder)

def set_label(self, label):
"""
Expand All @@ -469,25 +574,25 @@ def set_label(self, label):

# remove image from appropriate folder
if self.mode == 'copy':
os.remove(os.path.join(self.input_folder, label, img_name))
os.remove(join_path(self.input_folder, label, img_name))

elif self.mode == 'move':
# label was in assigned labels, so I want to remove it from label folder,
# but this was the last label, so move the image to input folder.
# Don't remove it, because it it not save anywehre else
if img_name not in self.assigned_labels.keys():
shutil.move(os.path.join(self.input_folder, label, img_name), self.input_folder)
shutil.move(join_path(self.input_folder, label, img_name), self.input_folder)
else:
# label was in assigned labels and the image is store in another label folder,
# so I want to remove it from current label folder
os.remove(os.path.join(self.input_folder, label, img_name))
os.remove(join_path(self.input_folder, label, img_name))

# label is not there yet. But the image has some labels already
else:
self.assigned_labels[img_name].append(label)

# path to copy/move images
copy_to = os.path.join(self.input_folder, label)
copy_to = join_path(self.input_folder, label)

# copy/move the image into appropriate label folder
if self.mode == 'copy':
Expand All @@ -497,15 +602,15 @@ def set_label(self, label):
elif self.mode == 'move':
# the image doesn't have to be stored in input_folder anymore.
# get the path where the image is stored
copy_from = os.path.join(self.input_folder, self.assigned_labels[img_name][0], img_name)
copy_from = join_path(self.input_folder, self.assigned_labels[img_name][0], img_name)
shutil.copy(copy_from, copy_to)

else:
# Image has no labels yet. Set new label and copy/move

self.assigned_labels[img_name] = [label]
# move copy images to appropriate directories
copy_to = os.path.join(self.input_folder, label)
copy_to = join_path(self.input_folder, label)

if self.mode == 'copy':
shutil.copy(img_path, copy_to)
Expand All @@ -531,15 +636,14 @@ def show_next_image(self):
# If we have already assigned label to this image and mode is 'move', change the input path.
# The reason is that the image was moved from '.../input_folder' to '.../input_folder/label'
if self.mode == 'move' and filename in self.assigned_labels.keys():
path = os.path.join(self.input_folder, self.assigned_labels[filename][0], filename)
path = join_path(self.input_folder, self.assigned_labels[filename][0], filename)

self.set_image(path)
self.img_name_label.setText(path)
self.progress_bar.setText(f'image {self.counter + 1} of {self.num_images}')
self.set_button_color(filename)
self.csv_generated_message.setText('')


# change button color if this is last image in dataset
elif self.counter == self.num_images - 1:
path = self.img_paths[self.counter]
Expand All @@ -559,7 +663,7 @@ def show_prev_image(self):
# If we have already assigned label to this image and mode is 'move', change the input path.
# The reason is that the image was moved from '.../input_folder' to '.../input_folder/label'
if self.mode == 'move' and filename in self.assigned_labels.keys():
path = os.path.join(self.input_folder, self.assigned_labels[filename][0], filename)
path = join_path(self.input_folder, self.assigned_labels[filename][0], filename)

self.set_image(path)
self.img_name_label.setText(path)
Expand Down Expand Up @@ -588,17 +692,17 @@ def set_image(self, path):
else:
pixmap = pixmap.scaledToHeight(self.img_panel_height - margin)

self.image_box.setPixmap(pixmap)
self.image_box.load_pixmap(pixmap)

def generate_csv(self, out_filename):
"""
Generates and saves csv file with assigned labels.
Assigned label is represented as one-hot vector.
:param out_filename: name of csv file to be generated
"""
path_to_save = os.path.join(self.input_folder, 'output')
path_to_save = join_path(self.input_folder, 'output')
make_folder(path_to_save)
csv_file_path = os.path.join(path_to_save, out_filename) + '.csv'
csv_file_path = join_path(path_to_save, out_filename) + '.csv'

with open(csv_file_path, "w", newline='') as csv_file:
writer = csv.writer(csv_file, delimiter=',')
Expand Down Expand Up @@ -654,6 +758,39 @@ def set_button_color(self, filename):
else:
button.setStyleSheet('background-color: None')

def show_img_name_label_menu(self, pos):
"""
Display menu when right click on image name label
"""
text = self.img_name_label.selectedText()
img_path = self.img_paths[self.counter]

menu = QtWidgets.QMenu()
copy_selected_action = menu.addAction('Copy Selected')
copy_dir_path_action = menu.addAction('Copy Dir Path')
copy_filename_action = menu.addAction('Copy Filename')
copy_fullpath_action = menu.addAction('Copy Fullpath')
menu.addSeparator()
select_action = menu.addAction('Select All')

if not text:
copy_selected_action.setEnabled(False)

# show the menu
action = menu.exec_(self.img_name_label.mapToGlobal(pos))

if action == copy_selected_action:
# if the menu has been triggered by the action, copy to the clipboard
QtWidgets.QApplication.clipboard().setText(text)
elif action == copy_dir_path_action:
QtWidgets.QApplication.clipboard().setText(os.path.split(img_path)[0])
elif action == copy_filename_action:
QtWidgets.QApplication.clipboard().setText(os.path.split(img_path)[-1])
elif action == copy_fullpath_action:
QtWidgets.QApplication.clipboard().setText(img_path)
elif action == select_action:
self.img_name_label.setSelection(0, len(img_path))

def closeEvent(self, event):
"""
This function is executed when the app is closed.
Expand All @@ -674,7 +811,7 @@ def labels_to_zero_one(self, labels):
label_to_int = dict((c, i) for i, c in enumerate(self.labels))

# initialize array to save selected labels
zero_one_arr = np.zeros([self.num_labels], dtype=int)
zero_one_arr = np.zeros([len(self.labels)], dtype=int)
for label in labels:
zero_one_arr[label_to_int[label]] = 1

Expand All @@ -683,7 +820,7 @@ def labels_to_zero_one(self, labels):
@staticmethod
def create_label_folders(labels, folder):
for label in labels:
make_folder(os.path.join(folder, label))
make_folder(join_path(folder, label))


if __name__ == '__main__':
Expand Down